{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dendron import *\n",
    "from dendron.actions.causal_lm_action import CausalLMActionConfig, CausalLMAction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = CausalLMActionConfig(load_in_4bit=True, \n",
    "                           max_new_tokens=64, \n",
    "                           do_sample=True, \n",
    "                           top_p=0.95, \n",
    "                           use_flash_attn_2=True,\n",
    "                           model_name = 'v1olet/v1olet_merged_dpo_7B')\n",
    "\n",
    "node = CausalLMAction(\"lm_action\", cfg)\n",
    "tree = BehaviorTree(\"causal-lm-tree\", node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"{system}\n",
    "### Instruction:\n",
    "{query}\n",
    "\n",
    "### Response:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def trunc(node):\n",
    "    out_str = node.blackboard[node.output_key]\n",
    "    m = re.search(r'### Response:(.*)', out_str, re.DOTALL)\n",
    "    node.blackboard[node.output_key] = m.group(1).strip()\n",
    "\n",
    "tree.root.add_post_tick(trunc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree.blackboard[\"in\"] = prompt.format(system=\"You are busy and don't have much time to be verbose.\", \n",
    "                                      query=\"The following is a headline from a newspaper: 'History-making SpaceX booster mostly destroyed in post-flight topple.' Is this about a spacecraft? Give a one-word yes or no.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "status = tree.tick_once()\n",
    "if status == NodeStatus.SUCCESS:\n",
    "     print(tree.blackboard[\"out\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tree.blackboard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree.blackboard['out']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
